"""Calculate the anat-to-template transformation using ANTs.

Deoblique subject anat and convert to NIFTI.
Align center of probabilistic map for skull stripping to the subject anat.
Create skull-stripped anatomical in subject native space.
Calculate non-linear alignment from skull-stripped subject anat to template.
Calculate linear alignment from skull-stripped subject anat to EPI.

The root directory must contain:
    This script
    Subdirectories named by subject ID with assumed filenames
    Anatomical template "TT_N27_SurfVol.nii"
    OASIS challenge data in "MICCAI2012-Multi-Atlas-Challenge-Data"

How to use it:
    python beau_anat_to_template.py {subject_id}

Important outputs:
    Skull-stripped brain
        {subject_id}/{subject_id}.BrainExtractionBrain.nii.gz
    Skull-stripped brain non-linearly warped to the template
        {subject_id}/{subject_id}.Warped.nii.gz
    Linear anatomical to EPI alignment
        {subject_id}/{subject_id}.BrainExtractionBrain_al_mat.aff12.1D
    Linear anatomical to template alignment
        {subject_id}/{subject_id}.0GenericAffine.mat
    Non-linear anatomical to template alignment
        {subject_id}/{subject_id}.1Warp.nii.gz
"""

import glob
import os
import subprocess
import sys


def deoblique(anat_path, output_path):
    """Deoblique an anatomical dataset with AFNI's ``3dWarp``.

    Runs ``3dWarp -deoblique -prefix <output_path> <anat_path>``.

    Args:
        anat_path: Dataset handed to 3dWarp as its input.
        output_path: Value handed to 3dWarp as ``-prefix``.

    Returns:
        None. The exit status of 3dWarp is not inspected.
    """
    # Pass an argument list instead of a shell string so that paths with
    # spaces or shell metacharacters reach 3dWarp verbatim; this is also how
    # the other command wrappers in this file invoke their tools.
    cmd = ['3dWarp', '-deoblique', '-prefix', output_path, anat_path]
    subprocess.call(cmd)


def align_centers(fixed_path, moving_path, output_prefix, output_dir,
                  atlas_data_path):
    """Center-align a dataset to a base with AFNI's ``@Align_Centers``.

    Runs ``@Align_Centers -base <fixed_path> -dset <moving_path>
    -prefix <output_prefix>.nii`` and then moves the two result files into
    ``output_dir``:

    * the shifted volume, looked up inside ``atlas_data_path`` as
      ``<output_prefix>.nii`` (or with ``.gz`` appended when that name is
      missing), and
    * ``<output_prefix>.1D``, looked up in the current working directory.

    Args:
        fixed_path: Dataset passed as ``-base``.
        moving_path: Dataset passed as ``-dset``.
        output_prefix: File name stem for the outputs (no directory part).
        output_dir: Directory the outputs are moved into.
        atlas_data_path: Directory in which the shifted volume is expected
            to appear. NOTE(review): presumably the directory containing
            ``moving_path`` -- confirm against @Align_Centers behavior.

    Returns:
        Path of the shifted volume inside ``output_dir``.

    Raises:
        OSError: from ``os.rename`` when an expected output file is missing
            (the exit status of @Align_Centers itself is not inspected).
    """
    print("Aligning centers.")
    print("Fixed: {0}".format(fixed_path))
    print("Moving: {0}".format(moving_path))
    output_filename = output_prefix + '.nii'
    print("output_filename: {0}".format(output_filename))
    # Argument list rather than a shell string: paths containing spaces or
    # shell metacharacters are passed through untouched, consistent with the
    # other command wrappers in this file.
    cmd = [
        '@Align_Centers',
        '-base', fixed_path,
        '-dset', moving_path,
        '-prefix', output_filename
    ]
    subprocess.call(cmd)

    # Handle AFNI path weirdness: the volume is searched for in
    # atlas_data_path rather than in the working directory, and it may have
    # been written gzip-compressed.
    print("Handling AFNI path weirdness.")
    temp_output_path = os.path.join(atlas_data_path, output_filename)
    if not os.path.exists(temp_output_path):
        print(
            "temp_output_path {0} does not exist. "
            "Adding .gz extension".format(temp_output_path)
        )
        temp_output_path += '.gz'
        output_filename += '.gz'
    output_path = os.path.join(output_dir, output_filename)
    print("output_path: {0}".format(output_path))
    print("Renaming {0} to {1}".format(temp_output_path, output_path))
    os.rename(temp_output_path, output_path)

    # The accompanying .1D file is expected in the working directory.
    oned_start = output_prefix + '.1D'
    oned_end = os.path.join(output_dir, output_prefix + '.1D')
    print("Renaming {0} to {1}".format(oned_start, oned_end))
    os.rename(oned_start, oned_end)
    return output_path


def extract_brain(anat_path, template_path, prob_al_path, mask_al_path,
                  output_path):
    """Skull-strip an anatomical by running ``antsBrainExtraction.sh``.

    The script is called in 3-D mode (``-d 3``) with ``anat_path`` as ``-a``,
    ``template_path`` as ``-e``, ``prob_al_path`` as ``-m``,
    ``mask_al_path`` as ``-f`` and ``output_path`` as ``-o``.

    Returns:
        None. The script's exit status is not inspected.
    """
    print("Extracting brain and creating anatomical brain mask...")
    flag_values = (
        ('-d', '3'),
        ('-a', anat_path),
        ('-e', template_path),
        ('-m', prob_al_path),
        ('-f', mask_al_path),
        ('-o', output_path),
    )
    command = ['antsBrainExtraction.sh']
    for flag, value in flag_values:
        command.extend((flag, value))
    subprocess.call(command)


def anat_to_template(anat_path, template_path, output_prefix, quick=False,
                     n_threads=8):
    """Register an anatomical to a template with the ANTs SyN scripts.

    Args:
        anat_path: Image passed as the moving image (``-m``).
        template_path: Image passed as the fixed image (``-f``).
        output_prefix: Value passed as ``-o``.
        quick: When true, ``antsRegistrationSyNQuick.sh`` is run instead of
            ``antsRegistrationSyN.sh``.
        n_threads: Value passed as ``-n`` (converted with ``str``).

    Returns:
        None. The script's exit status is not inspected.
    """
    thread_arg = str(n_threads)
    if quick:
        script = 'antsRegistrationSyNQuick.sh'
    else:
        script = 'antsRegistrationSyN.sh'
    subprocess.call([
        script,
        '-d', '3',
        '-f', template_path,
        '-m', anat_path,
        '-o', output_prefix,
        '-n', thread_arg,
    ])


def anat_to_epi(anat_path, epi_path):
    """Run AFNI's ``align_epi_anat.py`` on an anatomical/EPI pair.

    ``epi_path`` is passed as ``-epi`` and ``anat_path`` as ``-anat``; the
    remaining options are fixed: ``-volreg on``, ``-tshift off``,
    ``-epi_base 0``, ``-anat_has_skull no`` and ``-deoblique off``.

    Returns:
        None. The script's exit status is not inspected.
    """
    fixed_options = [
        '-volreg', 'on',
        '-tshift', 'off',
        '-epi_base', '0',
        '-anat_has_skull', 'no',
        '-deoblique', 'off',
    ]
    # Currently disabled alternatives: '-ginormous_move',
    # '-align_centers' 'yes', and '-giant_move'.
    command = ['align_epi_anat.py', '-epi', epi_path, '-anat', anat_path]
    command += fixed_options
    subprocess.call(command)


def gzfix(path):
    """Return ``path`` if it exists, otherwise ``path`` with ``.gz`` appended.

    Used after running a tool that may have written its output
    gzip-compressed. The ``.gz`` variant is returned without checking that
    it exists.
    """
    print("Compensating for AFNI auto-gzip...")
    if os.path.exists(path):
        return path
    print("Fixing path: {0}".format(path))
    return path + '.gz'

def set_metadata(warped_path):
    """Mark a dataset as TLRC by running AFNI's ``3drefit`` on it.

    Runs ``3drefit -view tlrc -space TLRC <warped_path>``.

    Args:
        warped_path: Dataset whose header is rewritten.

    Returns:
        None. The exit status of 3drefit is not inspected.
    """
    # Argument list instead of a shell string so the path is passed verbatim
    # even if it contains spaces or shell metacharacters.
    cmd = ['3drefit', '-view', 'tlrc', '-space', 'TLRC', warped_path]
    subprocess.call(cmd)

if __name__ == "__main__":
    # Exit with a usage message rather than an IndexError traceback when the
    # subject ID argument is missing.
    if len(sys.argv) < 2:
        sys.exit(
            "Usage: python {0} {{subject_id}}".format(
                os.path.basename(sys.argv[0])
            )
        )
    subj_id = sys.argv[1]
    quick = False  # True selects the quick SyN registration script.
    run_index = 2  # EPI run used for the anat-to-EPI alignment.

    # All paths are relative to the working directory, which must contain the
    # subject directory, the template and the atlas data directory.
    anat_path = os.path.join(subj_id, "{0}.anat+orig".format(subj_id))
    template_path = "TT_N27_SurfVol.nii"
    atlas_data_path = 'MICCAI2012-Multi-Atlas-Challenge-Data'

    strip_template_path = os.path.join(
        atlas_data_path,
        'T_template0.nii.gz'
    )
    strip_prob_path = os.path.join(
        atlas_data_path,
        'T_template0_BrainCerebellumProbabilityMask.nii.gz'
    )
    strip_mask_path = os.path.join(
        atlas_data_path,
        'T_template0_BrainCerebellumRegistrationMask.nii.gz'
    )

    print("Deobliquing anatomical...")
    deoblique_path = os.path.join(
        subj_id,
        "{0}.anat_deoblique.nii".format(subj_id)
    )
    deoblique(anat_path, deoblique_path)
    # The output may have been written as .nii.gz instead of .nii.
    deoblique_path = gzfix(deoblique_path)

    # The three skull-stripping inputs are each center-aligned to the
    # deobliqued anatomical; results land in the subject directory.
    print("Aligning center of skull stripping probability map...")
    strip_prob_al_path = align_centers(
        deoblique_path,
        strip_prob_path,
        "{0}.skullstrip_prob_al".format(subj_id),
        subj_id,
        atlas_data_path
    )

    print("Aligning center of skull stripping mask map...")
    strip_mask_al_path = align_centers(
        deoblique_path,
        strip_mask_path,
        "{0}.skullstrip_mask_al".format(subj_id),
        subj_id,
        atlas_data_path
    )

    print("Aligning center of skull stripping template map...")
    strip_template_al_path = align_centers(
        deoblique_path,
        strip_template_path,
        "{0}.skullstrip_template_al".format(subj_id),
        subj_id,
        atlas_data_path
    )

    # Both ANTs steps write their outputs as "<subj_id>/<subj_id>.<suffix>".
    ants_prefix = os.path.join(subj_id, "{0}.".format(subj_id))

    print("Skull stripping...")
    brain_path = os.path.join(
        subj_id,
        "{0}.BrainExtractionBrain.nii.gz".format(subj_id)
    )
    # Skipped when the skull-stripped brain already exists.
    if not os.path.exists(brain_path):
        extract_brain(
            deoblique_path,
            strip_template_al_path,
            strip_prob_al_path,
            strip_mask_al_path,
            ants_prefix
        )

    print("Calculating non-linear alignment of anatomical to template...")
    warped_path = os.path.join(
        subj_id,
        "{0}.Warped.nii.gz".format(subj_id)
    )
    # Skipped when the warped brain already exists.
    if not os.path.exists(warped_path):
        anat_to_template(
            brain_path,
            template_path,
            ants_prefix,
            quick=quick
        )

    print("Calculating linear alignment of anatomical to EPI...")
    test_path = os.path.join(
        subj_id,
        "{0}.BrainExtractionBrain_al_mat.aff12.1D".format(subj_id)
    )
    # Skipped when the alignment matrix is already in the subject directory.
    if not os.path.exists(test_path):
        epi_path = os.path.join(
            subj_id,
            "deoblw_{0}.run{1}+orig".format(subj_id, run_index)
        )
        anat_to_epi(brain_path, epi_path)

    print("Changing header info to TLRC...")
    set_metadata(warped_path)

    # Alignment outputs that are looked for in the working directory and, if
    # present, moved into the subject directory. The motion file name follows
    # run_index so it stays in sync with the EPI chosen above.
    output_name_templates = [
        "{0}.BrainExtractionBrain_al_e2a_only_mat.aff12.1D",
        "{0}.BrainExtractionBrain_al_mat.aff12.1D",
        "{0}.BrainExtractionBrain_al+orig.BRIK",
        "{0}.BrainExtractionBrain_al+orig.HEAD",
        "deoblw_{0}.run{1}_vr_motion.1D"
    ]
    for name_template in output_name_templates:
        filename = name_template.format(subj_id, run_index)
        if os.path.exists(filename):
            os.rename(filename, os.path.join(subj_id, filename))

    # Remove "__tt_*" files from the subject directory (presumably
    # temporaries left by the alignment step -- TODO confirm).
    delete_paths = glob.glob(os.path.join(subj_id, "__tt_*"))
    for delete_path in delete_paths:
        if os.path.exists(delete_path):
            os.remove(delete_path)

    print("All done.")